// ++++++++++++++++++++++++++++++++++++++++
// 《零基础Go语言算法实战》源码
// ++++++++++++++++++++++++++++++++++++++++
// Author:廖显东（ShirDon）
// Blog:https://www.shirdon.com/
// Gitee:https://gitee.com/shirdonl/goAlgorithms.git
// Buy link :https://item.jd.com/14101229.html
// ++++++++++++++++++++++++++++++++++++++++

package main

import (
	"fmt"
	"math"
)

// 逻辑回归结构体，包含权重、学习率和迭代次数
type LogisticRegression struct {
	weights    []float64
	lr         float64
	iterations int
}

// 创建新的逻辑回归对象并返回
func NewLogisticRegression(lr float64, iterations int) *LogisticRegression {
	return &LogisticRegression{lr: lr, iterations: iterations}
}

// 计算z的 sigmoid 函数值并返回
func (l *LogisticRegression) sigmoid(z float64) float64 {
	return 1.0 / (1.0 + math.Exp(-z))
}

// 预测方法，给定输入向量X，预测输出并返回
func (l *LogisticRegression) predict(X []float64) float64 {
	var z float64
	for i, xi := range X {
		z += xi * l.weights[i]
	}
	return l.sigmoid(z)
}

// 训练方法，给定输入矩阵X和输出向量y，训练模型的权重
func (l *LogisticRegression) train(X [][]float64, y []float64) {
	nSamples := len(X)
	nFeatures := len(X[0])
	l.weights = make([]float64, nFeatures)

	// 进行多次迭代，更新权重
	for i := 0; i < l.iterations; i++ {
		for j := 0; j < nSamples; j++ {
			yPred := l.predict(X[j])
			res := y[j] - yPred
			for k := 0; k < nFeatures; k++ {
				l.weights[k] += l.lr * res * X[j][k]
			}
		}
	}
}

// 预测方法，给定输入矩阵X，预测输出并返回
func (l *LogisticRegression) predictAll(X [][]float64) []float64 {
	nSamples := len(X)
	yPred := make([]float64, nSamples)
	for i, xi := range X {
		yPred[i] = l.predict(xi)
	}
	return yPred
}

func main() {
	// 输入矩阵X
	X := [][]float64{
		{1, 2},
		{2, 1},
		{3, 4},
		{4, 3},
	}
	//输出向量y
	y := []float64{0, 0, 1, 1}

	// 创建逻辑回归对象并训练模型
	lr := NewLogisticRegression(0.1, 100)
	lr.train(X, y)

	// 使用训练好的模型进行预测
	yPred := lr.predictAll(X)

	// 输出预测结果
	fmt.Println(yPred)
}

//$ go run logisticRegression.go
//[0.7203202998607646 0.6688645965102862 0.8854875088641785 0.858447853849889]
